from continual_rl.policies.arc.arc_monobeast import ARCMonobeast
from continual_rl.policies.arc.arc_policy_config import ARCPolicyConfig
from continual_rl.policies.impala.impala_policy import ImpalaPolicy

class ARCPolicy(ImpalaPolicy):
    """ImpalaPolicy subclass that trains with ``ARCMonobeast`` by default.

    For an ``ImpalaPolicy``, a subclass only needs to wrap the Monobeast
    class and the policy config; everything else is delegated to the parent.
    """

    def __init__(self, config: ARCPolicyConfig, observation_spaces, action_spaces, impala_class: ARCMonobeast = None,
                 policy_net_class=None):
        """Build the policy, falling back to ``ARCMonobeast`` as trainer class.

        Args:
            config: ARC policy configuration, forwarded to the parent.
            observation_spaces: Forwarded unchanged to ``ImpalaPolicy``.
            action_spaces: Forwarded unchanged to ``ImpalaPolicy``.
            impala_class: Trainer class (the class itself, not an instance)
                handed to the parent. ``None`` selects ``ARCMonobeast``.
            policy_net_class: Forwarded unchanged to ``ImpalaPolicy``.
        """
        # Only an explicit None triggers the default; any other value is
        # passed through untouched.
        trainer_cls = ARCMonobeast if impala_class is None else impala_class

        super().__init__(
            config,
            observation_spaces,
            action_spaces,
            impala_class=trainer_cls,
            policy_net_class=policy_net_class,
        )

    def task_change(self, task_run_id=0):
        """Notify the parent policy, then the trainer, of a task switch.

        Args:
            task_run_id: Identifier passed to both ``task_change`` hooks.
        """
        # Parent hook runs first; the trainer is looked up only afterwards.
        super().task_change(task_run_id)

        # NOTE(review): ``impala_trainer`` is presumably created by
        # ImpalaPolicy.__init__ -- confirm against the parent class.
        trainer = self.impala_trainer
        trainer.task_change(task_run_id)
